Synphony

Deep Learning Final Project - MSDS Spring Module 2 - 2025

Aditi Puttur & Emma Juan

1. Data Preprocessing¶

In [2]:
import pandas as pd
import numpy as np

import os
import json

from tqdm import tqdm

import re
import unicodedata

import warnings
warnings.filterwarnings("ignore")

from miditok import REMI, TokenizerConfig, TokSequence
from miditoolkit import MidiFile
from symusic import Score

os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau

import math
from typing import Optional

import traceback

Loading the data¶

LMD: Midi Files¶

In [3]:
# Open and read the JSON file mapping each MIDI md5 hash to its original
# path(s) inside the Lakh MIDI Dataset distribution.
with open('data/LMD/md5_to_paths.json', 'r') as file:
    md5_to_paths = json.load(file)
In [7]:
md5_to_paths['1c83fc02b8c57fbc2605900bb31793fb']
Out[7]:
['E/Exaltasamba - Megastar.mid',
 'Midis Samba e Pagode/Exaltasamba - Megastar.mid',
 'Midis Samba e Pagode/Exaltasamba - Megastar.mid']
In [9]:
# Recursively collect every .mid file under the matched-LMD tree.
lmd_catalog = []

for dirpath, _dirnames, filenames in os.walk('data/LMD/lmd_matched'):
    lmd_catalog.extend(
        os.path.join(dirpath, name)
        for name in filenames
        if name.endswith('.mid')
    )
In [10]:
lmd_catalog.sort()
lmd_catalog[:10]
Out[10]:
['data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/1d9d16a9da90c090809c153754823c2b.mid',
 'data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/5dd29e99ed7bd3cc0c5177a6e9de22ea.mid',
 'data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/b97c529ab9ef783a849b896816001748.mid',
 'data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/dac3cdd0db6341d8dc14641e44ed0d44.mid',
 'data/LMD/lmd_matched/A/A/A/TRAAAZF12903CCCF6B/05f21994c71a5f881e64f45c8d706165.mid',
 'data/LMD/lmd_matched/A/A/A/TRAAAZF12903CCCF6B/10288ea8e07b70c17f872fda82b94330.mid',
 'data/LMD/lmd_matched/A/A/A/TRAAAZF12903CCCF6B/6304d2bba4282f8bd74322828c30f0c7.mid',
 'data/LMD/lmd_matched/A/A/A/TRAAAZF12903CCCF6B/c24989559d170135b9c6546d1d2df20b.mid',
 'data/LMD/lmd_matched/A/A/A/TRAAAZF12903CCCF6B/ddb6a3db65461dca1a43de72f5375d8b.mid',
 'data/LMD/lmd_matched/A/A/A/TRAAAZF12903CCCF6B/dfea6fd75926c571a87db789280d059d.mid']
In [6]:
len(lmd_catalog)
Out[6]:
116189
In [7]:
# Build the MIDI catalog table: one row per file. The MSD track id is the
# parent directory name, the LMD id is the md5 file stem.
lmd_catalog_all = {
    'path': lmd_catalog,
    'MSD_name': [p.split('/')[-2] for p in lmd_catalog],
    'LMD_name': [p.split('/')[-1].split('.')[-2] for p in lmd_catalog],
}

lmd_df = pd.DataFrame(lmd_catalog_all)
lmd_df
Out[7]:
path MSD_name LMD_name
0 data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/... TRAAAGR128F425B14B 1d9d16a9da90c090809c153754823c2b
1 data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/... TRAAAGR128F425B14B 5dd29e99ed7bd3cc0c5177a6e9de22ea
2 data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/... TRAAAGR128F425B14B b97c529ab9ef783a849b896816001748
3 data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/... TRAAAGR128F425B14B dac3cdd0db6341d8dc14641e44ed0d44
4 data/LMD/lmd_matched/A/A/A/TRAAAZF12903CCCF6B/... TRAAAZF12903CCCF6B 05f21994c71a5f881e64f45c8d706165
... ... ... ...
116184 data/LMD/lmd_matched/Z/Z/Z/TRZZZTN128EF35C42F/... TRZZZTN128EF35C42F 165e156e5192569e41dc8390b80a1465
116185 data/LMD/lmd_matched/Z/Z/Z/TRZZZTN128EF35C42F/... TRZZZTN128EF35C42F 87e403b5fcb06718767aee0a9386f86c
116186 data/LMD/lmd_matched/Z/Z/Z/TRZZZTN128EF35C42F/... TRZZZTN128EF35C42F c56e00ecc890dfdfbdd551cb9ea15ca5
116187 data/LMD/lmd_matched/Z/Z/Z/TRZZZYV128F92E996D/... TRZZZYV128F92E996D 1b966417a9aa703873c5fa1cfe18da32
116188 data/LMD/lmd_matched/Z/Z/Z/TRZZZYV128F92E996D/... TRZZZYV128F92E996D 3bcd7e0cc20adcc8dc3e912623bb0e1b

116189 rows × 3 columns

In [8]:
lmd_df["MSD_name"].nunique()
Out[8]:
31034

Lakh MIDI Dataset (only tracks with matching metadata files) → 31,034 tracks / 116,189 MIDI files.

LMD-matched metadata (MillionSongDataset): The Metadata¶

We will extract title, artist and year from the metadata and add it to our dataset.

In [11]:
import hdf5_getters
In [12]:
msd_catalog = []
titles = []
artists = []
releases = []
years = []

# Walk the matched-MSD tree, collecting every .h5 metadata file and its
# title / artist / release / year fields.
for dirpath, dirnames, filenames in tqdm(os.walk('data/LMD-matched-MSD')):
    for file in filenames:
        full_path = os.path.join(dirpath, file)
        if full_path.endswith('.h5'):

            # Append the path to the list
            msd_catalog.append(full_path)

            # Get the metadata, always closing the HDF5 handle — the
            # original never closed it, leaving ~31k open file handles,
            # which can exhaust the OS file-descriptor limit mid-run.
            h5 = hdf5_getters.open_h5_file_read(full_path)
            try:
                titles.append(hdf5_getters.get_title(h5))
                artists.append(hdf5_getters.get_artist_name(h5))
                releases.append(hdf5_getters.get_release(h5))
                years.append(hdf5_getters.get_year(h5))
            finally:
                h5.close()
15298it [07:23, 34.52it/s]
In [13]:
msd_catalog[:10]
Out[13]:
['data/LMD-matched-MSD/R/R/U/TRRRUFD12903CD7092.h5',
 'data/LMD-matched-MSD/R/R/U/TRRRUTV12903CEA11B.h5',
 'data/LMD-matched-MSD/R/R/U/TRRRUJO128E07813E7.h5',
 'data/LMD-matched-MSD/R/R/I/TRRRIYO128F428CF6F.h5',
 'data/LMD-matched-MSD/R/R/I/TRRRILO128F422FFED.h5',
 'data/LMD-matched-MSD/R/R/I/TRRRIVC12903CA6C5A.h5',
 'data/LMD-matched-MSD/R/R/I/TRRRILD128F92CB682.h5',
 'data/LMD-matched-MSD/R/R/I/TRRRION128F145EBB7.h5',
 'data/LMD-matched-MSD/R/R/N/TRRRNPV128F42AAA55.h5',
 'data/LMD-matched-MSD/R/R/N/TRRRNGS12903CD16D9.h5']
In [12]:
len(msd_catalog)
Out[12]:
31034
In [13]:
len(msd_catalog) == lmd_df["MSD_name"].nunique()
Out[13]:
True
In [14]:
titles[:5]
Out[14]:
[b'Wastelands',
 b'Runaway',
 b'Have You Met Miss Jones? (Swing When Version)',
 b'Goodbye',
 b'La Colegiala']
In [15]:
artists[:5]
Out[15]:
[b'Hawkwind',
 b'Del Shannon',
 b'Robbie Williams',
 b'Volebeats',
 b'Rodolfo Y Su Tipica Ra7']
In [16]:
years[:5]
Out[16]:
[1994, 1961, 2001, 0, 1997]
In [17]:
# MSD stores strings as raw bytes (see the b'...' values above); decode to
# UTF-8 str so pandas and the slug helper can work with them.
titles = [title.decode('utf-8') for title in titles]
artists = [artist.decode('utf-8') for artist in artists]
In [18]:
# Assemble the MSD metadata table; the track id is the .h5 file stem.
msd_catalog_all = {
    'path': msd_catalog,
    'MSD_name': [p.split('/')[-1].split('.')[-2] for p in msd_catalog],
    'title': titles,
    'artist': artists,
    'year': years,
}

msd_df = pd.DataFrame(msd_catalog_all)
msd_df
Out[18]:
path MSD_name title artist year
0 data/LMD-matched-MSD/R/R/U/TRRRUFD12903CD7092.h5 TRRRUFD12903CD7092 Wastelands Hawkwind 1994
1 data/LMD-matched-MSD/R/R/U/TRRRUTV12903CEA11B.h5 TRRRUTV12903CEA11B Runaway Del Shannon 1961
2 data/LMD-matched-MSD/R/R/U/TRRRUJO128E07813E7.h5 TRRRUJO128E07813E7 Have You Met Miss Jones? (Swing When Version) Robbie Williams 2001
3 data/LMD-matched-MSD/R/R/I/TRRRIYO128F428CF6F.h5 TRRRIYO128F428CF6F Goodbye Volebeats 0
4 data/LMD-matched-MSD/R/R/I/TRRRILO128F422FFED.h5 TRRRILO128F422FFED La Colegiala Rodolfo Y Su Tipica Ra7 1997
... ... ... ... ... ...
31029 data/LMD-matched-MSD/W/W/Y/TRWWYHD12903CC42B1.h5 TRWWYHD12903CC42B1 Gethsemane (I Only Want to Say) (Live-LP Version) Michael Crawford 0
31030 data/LMD-matched-MSD/W/W/Y/TRWWYNJ128F426541F.h5 TRWWYNJ128F426541F Cold Feelings Social Distortion 1992
31031 data/LMD-matched-MSD/W/W/P/TRWWPSV128F4244C71.h5 TRWWPSV128F4244C71 Ases Death At Vance 2001
31032 data/LMD-matched-MSD/W/W/P/TRWWPBK128F42911E9.h5 TRWWPBK128F42911E9 Drowned Maid Amorphis 1993
31033 data/LMD-matched-MSD/W/W/W/TRWWWUT128F9364D1A.h5 TRWWWUT128F9364D1A Ting-A-Ling A Balladeer 0

31034 rows × 5 columns

In [19]:
msd_df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 31034 entries, 0 to 31033
Data columns (total 5 columns):
 #   Column    Non-Null Count  Dtype 
---  ------    --------------  ----- 
 0   path      31034 non-null  object
 1   MSD_name  31034 non-null  object
 2   title     31034 non-null  object
 3   artist    31034 non-null  object
 4   year      31034 non-null  int32 
dtypes: int32(1), object(4)
memory usage: 1.1+ MB

Million Song Dataset (only the files for the matched LMD dataset) → 31,034 metadata files (.h5 format)

tagtraum: Adding Genre Tags¶

In [20]:
tagtraum = {'MSD_name': [],
            'genre': []}

# Parse the tagtraum CD2C annotations: one "<track_id>\t<genre>" pair per
# line; lines starting with '#' are comments. Streams the file instead of
# readlines(), and tolerates blank or malformed lines (the original's
# `track, genre = line.split('\t')` raised ValueError on e.g. a trailing
# blank line).
with open("data/tagtraum/msd_tagtraum_cd2c.cls", "r") as file:
    for line in file:
        line = line.strip()
        if not line or line.startswith('#'):
            continue
        parts = line.split('\t')
        if len(parts) < 2:
            continue  # skip malformed lines rather than aborting the whole parse
        tagtraum['MSD_name'].append(parts[0])
        tagtraum['genre'].append(parts[1])
In [21]:
tagtraum_df = pd.DataFrame(tagtraum)
tagtraum_df
Out[21]:
MSD_name genre
0 TRAAAAK128F9318786 Rock
1 TRAAAAW128F429D538 Rap
2 TRAAADJ128F4287B47 Rock
3 TRAAADZ128F9348C2E Latin
4 TRAAAED128E0783FAB Jazz
... ... ...
191396 TRZZZMY128F426D7A2 Reggae
191397 TRZZZRJ128F42819AF Rock
191398 TRZZZUK128F92E3C60 Folk
191399 TRZZZZD128F4236844 Rock
191400 TRZZZZZ12903D05E3A Electronic

191401 rows × 2 columns

In [22]:
tagtraum_df["genre"].unique()
Out[22]:
array(['Rock', 'Rap', 'Latin', 'Jazz', 'Electronic', 'Pop', 'Metal',
       'RnB', 'Country', 'Reggae', 'Blues', 'Folk', 'Punk', 'World',
       'New Age'], dtype=object)

Tagtraum genre tags → 191,401 tags

Creating our dataset: MIDI + Metadata + Genres¶

Midi + Metadata¶

Each track (MSD_name -> track_id) has one metadata file, and different MIDI files (LMD_name -> midi_id) associated with it.

In [23]:
len(lmd_df), len(msd_df)
Out[23]:
(116189, 31034)
In [24]:
lmd_df["MSD_name"].nunique(), len(msd_df)
Out[24]:
(31034, 31034)
In [25]:
# Join the MIDI catalog with the MSD metadata on the track id, then keep a
# tidy subset of columns (metadata_filepath is no longer needed downstream).
dataset = (
    lmd_df
    .merge(msd_df, how="inner", on="MSD_name", suffixes=('_lmd', '_msd'))
    .rename(columns={"path_lmd": "midi_filepath",
                     "path_msd": "metadata_filepath",
                     "MSD_name": "track_id",
                     "LMD_name": "midi_id"})
)
dataset = dataset[["track_id", "midi_id", "midi_filepath",
                   "title", "artist", "year"]]
dataset
Out[25]:
track_id midi_id midi_filepath title artist year
0 TRAAAGR128F425B14B 1d9d16a9da90c090809c153754823c2b data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/... Into The Nightlife Cyndi Lauper 2008
1 TRAAAGR128F425B14B 5dd29e99ed7bd3cc0c5177a6e9de22ea data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/... Into The Nightlife Cyndi Lauper 2008
2 TRAAAGR128F425B14B b97c529ab9ef783a849b896816001748 data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/... Into The Nightlife Cyndi Lauper 2008
3 TRAAAGR128F425B14B dac3cdd0db6341d8dc14641e44ed0d44 data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/... Into The Nightlife Cyndi Lauper 2008
4 TRAAAZF12903CCCF6B 05f21994c71a5f881e64f45c8d706165 data/LMD/lmd_matched/A/A/A/TRAAAZF12903CCCF6B/... Break My Stride Matthew Wilder 1983
... ... ... ... ... ... ...
116184 TRZZZTN128EF35C42F 165e156e5192569e41dc8390b80a1465 data/LMD/lmd_matched/Z/Z/Z/TRZZZTN128EF35C42F/... Funky Dance Music Vol 1 DJ Rob E 0
116185 TRZZZTN128EF35C42F 87e403b5fcb06718767aee0a9386f86c data/LMD/lmd_matched/Z/Z/Z/TRZZZTN128EF35C42F/... Funky Dance Music Vol 1 DJ Rob E 0
116186 TRZZZTN128EF35C42F c56e00ecc890dfdfbdd551cb9ea15ca5 data/LMD/lmd_matched/Z/Z/Z/TRZZZTN128EF35C42F/... Funky Dance Music Vol 1 DJ Rob E 0
116187 TRZZZYV128F92E996D 1b966417a9aa703873c5fa1cfe18da32 data/LMD/lmd_matched/Z/Z/Z/TRZZZYV128F92E996D/... Dear Lie TLC 1999
116188 TRZZZYV128F92E996D 3bcd7e0cc20adcc8dc3e912623bb0e1b data/LMD/lmd_matched/Z/Z/Z/TRZZZYV128F92E996D/... Dear Lie TLC 1999

116189 rows × 6 columns

Some tracks have multiple MIDI files. We will only keep one MIDI file per track.

In [26]:
# Keep one MIDI file per track: take the first file of each track_id group,
# then attach the (deduplicated) track-level metadata back on.
first_midi = dataset.groupby('track_id').first().reset_index()
first_midi = first_midi[['track_id', 'midi_id', 'midi_filepath']]

track_meta = dataset[['track_id', "title", "artist", "year"]].drop_duplicates()

grouped_dataset = first_midi.merge(track_meta, on='track_id', how='left')
grouped_dataset = grouped_dataset[["track_id", "midi_id", "midi_filepath",
                                   "title", "artist", "year"]]
grouped_dataset
Out[26]:
track_id midi_id midi_filepath title artist year
0 TRAAAGR128F425B14B 1d9d16a9da90c090809c153754823c2b data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/... Into The Nightlife Cyndi Lauper 2008
1 TRAAAZF12903CCCF6B 05f21994c71a5f881e64f45c8d706165 data/LMD/lmd_matched/A/A/A/TRAAAZF12903CCCF6B/... Break My Stride Matthew Wilder 1983
2 TRAABVM128F92CA9DC 0dd4d2b9fbcf96a0fa363a1918255e58 data/LMD/lmd_matched/A/A/B/TRAABVM128F92CA9DC/... Caught In A Dream Tesla 2004
3 TRAABXH128F42955D6 01ffb8729a2465bfa7f9ba0288c89e24 data/LMD/lmd_matched/A/A/B/TRAABXH128F42955D6/... Keep An Eye On Summer (Album Version) Brian Wilson 1998
4 TRAACQE12903CC706C 1ee7c9ad5f18b2659789d9608c951ca5 data/LMD/lmd_matched/A/A/C/TRAACQE12903CC706C/... Summer Old Man River 2007
... ... ... ... ... ... ...
31029 TRZZYLO12903CAC06C 128551e12d6dec38ad7ce00665c77fe5 data/LMD/lmd_matched/Z/Z/Y/TRZZYLO12903CAC06C/... I've Never Seen The Righteous Forsaken Dallas Holm 0
31030 TRZZYTX128F92EBE33 538838021299e65875a8bec61a87a368 data/LMD/lmd_matched/Z/Z/Y/TRZZYTX128F92EBE33/... I Don't Want To Do It (2009 Digital Remaster) George Harrison 0
31031 TRZZZBU128F426811B 0702ddab7728f7b0e51321d8a0366367 data/LMD/lmd_matched/Z/Z/Z/TRZZZBU128F426811B/... Dame Una Señal Los Iracundos 0
31032 TRZZZTN128EF35C42F 165e156e5192569e41dc8390b80a1465 data/LMD/lmd_matched/Z/Z/Z/TRZZZTN128EF35C42F/... Funky Dance Music Vol 1 DJ Rob E 0
31033 TRZZZYV128F92E996D 1b966417a9aa703873c5fa1cfe18da32 data/LMD/lmd_matched/Z/Z/Z/TRZZZYV128F92E996D/... Dear Lie TLC 1999

31034 rows × 6 columns

Adding the genre tags¶

In [27]:
# Attach genre tags; the inner join keeps only tracks with a tagtraum label.
dataset = (
    dataset
    .merge(tagtraum_df, how="inner", left_on="track_id", right_on="MSD_name")
    .drop(columns=["MSD_name"])
)
dataset
Out[27]:
track_id midi_id midi_filepath title artist year genre
0 TRAAAGR128F425B14B 1d9d16a9da90c090809c153754823c2b data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/... Into The Nightlife Cyndi Lauper 2008 Pop
1 TRAAAGR128F425B14B 5dd29e99ed7bd3cc0c5177a6e9de22ea data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/... Into The Nightlife Cyndi Lauper 2008 Pop
2 TRAAAGR128F425B14B b97c529ab9ef783a849b896816001748 data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/... Into The Nightlife Cyndi Lauper 2008 Pop
3 TRAAAGR128F425B14B dac3cdd0db6341d8dc14641e44ed0d44 data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/... Into The Nightlife Cyndi Lauper 2008 Pop
4 TRAABVM128F92CA9DC 0dd4d2b9fbcf96a0fa363a1918255e58 data/LMD/lmd_matched/A/A/B/TRAABVM128F92CA9DC/... Caught In A Dream Tesla 2004 Rock
... ... ... ... ... ... ... ...
21348 TRZZROL12903CAC4A8 0f0aaf2f90bc66da732f4371e703eae4 data/LMD/lmd_matched/Z/Z/R/TRZZROL12903CAC4A8/... Love Love Amy MacDonald 2010 Pop
21349 TRZZSML12903CBB7BD bc4aae694e7c433a6da16284e52e11be data/LMD/lmd_matched/Z/Z/S/TRZZSML12903CBB7BD/... Airwave (Radio Edit) Rank 1 2000 Electronic
21350 TRZZTHP128F427F139 b085f5c3571f570bdc44fa0c9b6a0672 data/LMD/lmd_matched/Z/Z/T/TRZZTHP128F427F139/... Briaris The Sweetest Ache 1992 Rock
21351 TRZZTHP128F427F139 f10a54a5e8b4d169eec5231bb6b15c94 data/LMD/lmd_matched/Z/Z/T/TRZZTHP128F427F139/... Briaris The Sweetest Ache 1992 Rock
21352 TRZZXJE12903CD1D93 7723a2ff572a0b49f9d0e552313f7db7 data/LMD/lmd_matched/Z/Z/X/TRZZXJE12903CD1D93/... Warm and Tender Love Percy Sledge 1967 RnB

21353 rows × 7 columns

In [28]:
# Same genre join for the one-MIDI-per-track table.
grouped_dataset = (
    grouped_dataset
    .merge(tagtraum_df, how="inner", left_on="track_id", right_on="MSD_name")
    .drop(columns=["MSD_name"])
)
grouped_dataset
Out[28]:
track_id midi_id midi_filepath title artist year genre
0 TRAAAGR128F425B14B 1d9d16a9da90c090809c153754823c2b data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/... Into The Nightlife Cyndi Lauper 2008 Pop
1 TRAABVM128F92CA9DC 0dd4d2b9fbcf96a0fa363a1918255e58 data/LMD/lmd_matched/A/A/B/TRAABVM128F92CA9DC/... Caught In A Dream Tesla 2004 Rock
2 TRAAGMC128F4292D0F 0644195d1a3d14e0a0bd3d8b30dc68da data/LMD/lmd_matched/A/A/G/TRAAGMC128F4292D0F/... My Love (Album Version) LITTLE TEXAS 0 Country
3 TRAANZE128F148BF55 0597bf18743a5aacfedc981eb58c9da9 data/LMD/lmd_matched/A/A/N/TRAANZE128F148BF55/... The Name Of The Game Abba 1977 Pop
4 TRAAPPQ128F14961F5 d39a20f33af4fb6b307529db8cf0cc3f data/LMD/lmd_matched/A/A/P/TRAAPPQ128F14961F5/... Wig The B-52's 1986 Rock
... ... ... ... ... ... ... ...
6175 TRZZQGM128F9311E60 34d27fedd8dca07e36f50d69ba477e5b data/LMD/lmd_matched/Z/Z/Q/TRZZQGM128F9311E60/... Sun Of Jamaica Goombay Dance Band 1991 Pop
6176 TRZZROL12903CAC4A8 0f0aaf2f90bc66da732f4371e703eae4 data/LMD/lmd_matched/Z/Z/R/TRZZROL12903CAC4A8/... Love Love Amy MacDonald 2010 Pop
6177 TRZZSML12903CBB7BD bc4aae694e7c433a6da16284e52e11be data/LMD/lmd_matched/Z/Z/S/TRZZSML12903CBB7BD/... Airwave (Radio Edit) Rank 1 2000 Electronic
6178 TRZZTHP128F427F139 b085f5c3571f570bdc44fa0c9b6a0672 data/LMD/lmd_matched/Z/Z/T/TRZZTHP128F427F139/... Briaris The Sweetest Ache 1992 Rock
6179 TRZZXJE12903CD1D93 7723a2ff572a0b49f9d0e552313f7db7 data/LMD/lmd_matched/Z/Z/X/TRZZXJE12903CD1D93/... Warm and Tender Love Percy Sledge 1967 RnB

6180 rows × 7 columns

When we put the three datasets together, we ended up with 21,353 MIDI files covering 6,180 unique tracks, each with a metadata file and a genre tag.

Sluggifying our parameters¶

Slug‑safe metadata – ASCII‑safe, ALL_CAPS slugs for 15 genres, 2,956 artists, 60 years (between 1945 - 2010).

In [29]:
# Unique conditioning values; each will become a special vocabulary token.
genres = dataset["genre"].unique()
artists = dataset["artist"].unique()
years = dataset["year"].unique()
In [30]:
def slug(text: str) -> str:
    """Return an ALL_CAPS alnum/underscore version of `text`.

    Pipeline: strip accents to plain ASCII, turn every run of
    non-word characters into a single underscore, collapse repeated
    underscores, trim edge underscores, and upper-case the result.
    """
    ascii_text = (
        unicodedata.normalize("NFKD", text)
        .encode("ascii", "ignore")
        .decode()
    )
    underscored = re.sub(r"\W+", "_", ascii_text)
    collapsed = re.sub(r"_+", "_", underscored)
    return collapsed.strip("_").upper()
In [31]:
genres_slugged = np.array([slug(genre) for genre in genres])
artists_slugged = np.array([slug(artist) for artist in artists])
# Years stay numeric (no slug needed); drop NaNs before casting to int.
years = np.array([int(year) for year in years if not pd.isna(year)])
In [32]:
# Lookup tables pairing each raw value with its slug (years need no slug).
genres = pd.DataFrame({'genre': genres, 'slugged_genre': genres_slugged})

artists = pd.DataFrame({'artist': artists, 'slugged_artist': artists_slugged})

years = pd.DataFrame({'year': years})
In [33]:
# Sort each lookup table for stable, readable CSV output.
genres = genres.sort_values('genre')
artists = artists.sort_values('artist')
years = years.sort_values('year')
In [34]:
# Build each slug lookup Series once and reuse it for both tables
# (the original rebuilt set_index(...) four times).
genre_lookup = genres.set_index('genre')['slugged_genre']
artist_lookup = artists.set_index('artist')['slugged_artist']

dataset["slugged_genre"] = dataset["genre"].map(genre_lookup)
dataset["slugged_artist"] = dataset["artist"].map(artist_lookup)

grouped_dataset["slugged_genre"] = grouped_dataset["genre"].map(genre_lookup)
grouped_dataset["slugged_artist"] = grouped_dataset["artist"].map(artist_lookup)

Saving our data¶

Saving the metadata datasets¶

In [35]:
dataset.to_csv("data/metadata.csv", index=False)
In [36]:
grouped_dataset.to_csv("data/grouped_metadata.csv", index=False)

Saving the different parameters to csvs¶

In [37]:
# Persist the slug lookup tables so section 2 can rebuild the special-token
# vocabulary without re-running the preprocessing above.
genres.to_csv("data/genres.csv", index=False)
artists.to_csv("data/artists.csv", index=False)
years.to_csv("data/years.csv", index=False)

2. Model Implementation¶

In [85]:
# Reload everything saved in section 1 so section 2 can run standalone.
dataset = pd.read_csv("data/metadata.csv")
grouped_dataset = pd.read_csv("data/grouped_metadata.csv")

genres = pd.read_csv("data/genres.csv")
# NOTE(review): data/titles.csv is never written in section 1 (only genres,
# artists and years are saved) and `titles` is not used anywhere below —
# confirm the file exists from another source, otherwise Run-All fails here.
titles = pd.read_csv("data/titles.csv")
artists = pd.read_csv("data/artists.csv")
years = pd.read_csv("data/years.csv")
In [86]:
genres_slugged = genres["slugged_genre"].values
artists_slugged = artists["slugged_artist"].values
years_vals = years["year"].values
In [87]:
# Config with which the model was originally trained (kept for reference):
# MAX_TOKENS = 512
# BATCH_SIZE = 2

# D_MODEL    = 512
# N_LAYERS   = 6
# N_HEADS    = 8

# New config to try
MAX_TOKENS = 1024  # maximum token-sequence length fed to the model
BATCH_SIZE = 8

D_MODEL = 768      # embedding / hidden width
N_LAYERS = 8       # number of decoder blocks
N_HEADS = 12 # 768 / 12 = 64 per head

Tokenisation¶

Tokenisation converts variable‑length MIDI into a single integer stream compatible with text‑style language modelling, while injecting controllable style cues.

Defining the tokenizer¶

Library: miditok‑REMI with config:

  • use_chords=True
  • use_programs=True
  • 32 velocity bins
  • beat‑resolution {(0‑4):8,(4‑8):4}
  • rests and time‑signatures enabled
In [88]:
# REMI tokenizer: 32 velocity bins, chord & program tokens, rests and time
# signatures enabled. beat_res = 8 positions/beat for beats 0-4 and
# 4 positions/beat for beats 4-8.
config = TokenizerConfig(
    num_velocities=32,
    use_chords=True,
    use_programs=True,
    beat_res={(0,4): 8, (4,8): 4},
    use_rests=True,
    rest_range=(2,8),  # NOTE(review): newer miditok versions use `beat_res_rest` — confirm the installed version honours `rest_range`
    use_time_signatures=True
)

tokenizer = REMI(config)

Adding our special tokens¶

Conditioning: We prepend three special tokens per piece: `<GENRE_X>` `<ARTIST_Y>` `<YEAR_Z>` (vocab extended programmatically). Each full sequence ends with `<EOS>`.

In [89]:
# Extend the REMI vocabulary with the conditioning tokens plus <EOS>/<PAD>.
special_toks = (
    [f"<GENRE_{g}>" for g in genres_slugged]
    + [f"<ARTIST_{a}>" for a in artists_slugged]
    + [f"<YEAR_{y}>" for y in years_vals]
    + ["<EOS>", "<PAD>"]
)

for tok in special_toks:
    tokenizer.add_to_vocab(tok)

Tokenising: Storing each track as a numpy int32 array.¶

In [90]:
tokenizing = False  # master switch: set True to (re-)tokenise the corpus below; False lets Run-All skip the expensive pass and reuse the saved .npy files
In [91]:
# ─── 1. Helpers ──────────────────────────────────────────────────────────
def build_prefix(genre, artist, year, tokenizer):
    """Convert metadata values → list[int] of conditioning token ids.

    Looks up the <GENRE_…>, <ARTIST_…> and <YEAR_…> special tokens
    (added to the vocabulary above) via `tokenizer.vocab`.
    """
    vocab = tokenizer.vocab  # token string → int id
    return [
        vocab[f"<GENRE_{genre}>"],
        vocab[f"<ARTIST_{artist}>"],
        vocab[f"<YEAR_{year}>"],
    ]

# ─── 3. Output directory -------------------------------------------------
out_dir = "data/tokens/train"

# ─── 4. Iterate files ----------------------------------------------------
if tokenizing:
    os.makedirs(out_dir, exist_ok=True)  # make the save path safe on a fresh checkout
    n_rows = len(grouped_dataset)
    for idx in tqdm(range(n_rows)):
        # Defined before `try` so the except block can always report it
        # (the original raised NameError if the very first statement failed).
        midi_path = None
        try:
            # 4.0. Get row — distinct name; the original rebound the loop
            # variable (`row = grouped_dataset.iloc[row]`), which worked but
            # shadowed the index and obscured the logic.
            row = grouped_dataset.iloc[idx]

            # 4.1. MIDI filepath and track ID
            midi_path = row["midi_filepath"]
            track_id = row["track_id"]

            # 4a. Build CONDITIONING prefix → list[int]
            prefix_ids = build_prefix(row["slugged_genre"], row["slugged_artist"],
                                      row["year"], tokenizer)

            # 4b. Encode MIDI to tokens (TokSequence — ids accessed below)
            midi = Score(midi_path)
            midi_tokens = tokenizer(midi)

            # 4c. Concatenate prefix + midi + <EOS>
            seq_ids = prefix_ids + midi_tokens.ids + [tokenizer.vocab["<EOS>"]]

            # 4d. Save as int32 .npy keyed by track id
            np.save(f"{out_dir}/{track_id}.npy", np.array(seq_ids, dtype=np.int32))
        except Exception as e:
            print(f"Error processing {midi_path}: {e}")
            traceback.print_exc()
            continue

The Model¶

Synphony is a decoder‑only Transformer built in PyTorch 2.2 for autoregressive token prediction:

  1. Embedding – each of the 3,534 tokens is projected to a 768‑dimensional vector.
  2. Relative sinusoidal positional encoding – max sequence length 1024; lets the model extrapolate beyond training lengths.
  3. 8 × TransformerDecoderBlock – every block contains 12‑head self‑attention (64 d per head), residual LayerNorm, a GELU feed‑forward layer, and dropout 0.1.
    • Causal and pad masks are merged into a single FP32 attention mask to avoid memory blow‑ups.
  4. LayerNorm + Linear head – normalise the final hidden state and project back to the vocabulary for next‑token logits.

In shorthand: Embedding → PosEnc → 8 × DecoderBlock → LayerNorm → Linear, giving the model enough depth and width to capture harmonic and rhythmic structure while remaining trainable on Apple‑Silicon hardware.

In [92]:
class RelativePositionalEncoding(nn.Module):
    """
    Sinusoidal positional encoding, returned with the same trailing shape
    as the input so it can simply be added:  x + pos(x).

    The (max_len, d_model) table is built once at construction time —
    even dimensions carry sin, odd dimensions carry cos — and registered
    as a buffer so it follows the module across .to(device) calls.

    Args
    ----
    d_model : int            # embedding size
    max_len : int, optional  # maximum sequence length
    """
    def __init__(self, d_model: int, max_len: int = 2048):
        super().__init__()
        self.d_model = d_model
        self.max_len = max_len

        # Standard transformer frequencies: 10000^(-2i/d_model)
        positions = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
        inv_freq = torch.exp(
            -(math.log(10000.0) / d_model)
            * torch.arange(0, d_model, 2, dtype=torch.float)
        )
        angles = positions * inv_freq                # (L, D/2)

        table = torch.zeros(max_len, d_model)        # (L, D)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)

        self.register_buffer("pe", table)            # (L, D)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x : Tensor, shape (batch, seq_len, d_model)

        Returns
        -------
        pos : Tensor, shape (1, seq_len, d_model) — broadcasts over batch.

        Raises
        ------
        ValueError if seq_len exceeds the precomputed max_len.
        """
        seq_len = x.size(1)
        if seq_len > self.max_len:
            raise ValueError(f"Sequence length {seq_len} exceeds max_len {self.max_len}")
        return self.pe[:seq_len].unsqueeze(0)
In [93]:
class TransformerDecoderBlock(nn.Module):
    """
    Decoder block (self-attention only, no cross-attention) that merges
    causal + pad masking into a (B×H, L, L) float mask, so no hidden
    bool→float blow-ups occur inside nn.MultiheadAttention.

    Layout: self-attention → residual + LayerNorm → 4×d_model GELU
    feed-forward → residual + LayerNorm (post-norm).

    Args
    ----
    d_model : int    # embedding size
    n_heads : int    # attention heads (d_model must divide evenly)
    max_len : int    # longest sequence the precomputed causal mask supports
    dropout : float  # dropout on attention weights and residual branches
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        max_len: int = 2048,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            embed_dim   = d_model,
            num_heads   = n_heads,
            dropout     = dropout,
            batch_first = True,
        )
        self.ln1      = nn.LayerNorm(d_model)
        self.ln2      = nn.LayerNorm(d_model)
        self.ff       = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )
        self.dropout  = nn.Dropout(dropout)

        # Precompute float causal mask once: 0 on/under diag, -inf above
        # (future positions). persistent=False keeps it out of checkpoints.
        causal = torch.triu(
            torch.full((max_len, max_len), float("-inf")),
            diagonal=1
        )
        self.register_buffer("causal_mask", causal, persistent=False)

    def forward(
        self,
        x: torch.Tensor,            # (B, L, D)
        pad_mask: Optional[torch.Tensor] = None  # (B, L), True=keep token, False=pad
    ) -> torch.Tensor:
        B, L, _ = x.shape
        H       = self.self_attn.num_heads
        device  = x.device
        dtype   = x.dtype

        # 1) slice the (L×L) causal mask to the current sequence length
        causal = self.causal_mask[:L, :L]              # float32, (L, L)

        # 2) build a (B, L) float pad mask: 0 on tokens, -inf on pads
        if pad_mask is not None:
            pad_float = torch.zeros((B, L), device=device, dtype=dtype)
            pad_float = pad_float.masked_fill(~pad_mask, float("-inf"))
            # 3) expand pad_float to (B, L, L) and add causal
            #    pad_float.unsqueeze(1): (B, 1, L) → broadcast over src_len,
            #    i.e. padded KEY positions are masked for every query.
            # NOTE(review): a fully-padded row (pad_mask all False) would make
            # every logit -inf and softmax NaN; the collate_fn always writes at
            # least the 3 conditioning tokens, but worth confirming upstream.
            attn_batch = causal.unsqueeze(0) + pad_float.unsqueeze(1)  # (B, L, L)
        else:
            attn_batch = causal                               # (L, L)

        # 4) if we have a batch, repeat per-head to (B×H, L, L) — the layout
        #    nn.MultiheadAttention expects for a 3-D attn_mask
        if pad_mask is not None:
            # attn_batch: (B, L, L) → repeat each batch H times
            attn_mask = attn_batch.repeat_interleave(H, dim=0)  # (B*H, L, L)
        else:
            attn_mask = attn_batch   # 2D mask, broadcast over batch and heads

        # 5) self-attention with ONLY attn_mask (no separate key_padding_mask,
        #    avoiding the internal bool→float conversion)
        attn_out, _ = self.self_attn(
            x, x, x,
            attn_mask=attn_mask
        )

        # 6) residual + norm, then feed-forward + residual + norm
        x = self.ln1(x + self.dropout(attn_out))
        x = self.ln2(x + self.dropout(self.ff(x)))
        return x
In [94]:
class Synphony(nn.Module):
    """Decoder-only Transformer for autoregressive music-token prediction.

    Pipeline: Embedding → sinusoidal positional encoding →
    n_layers × TransformerDecoderBlock → LayerNorm → Linear vocabulary head.
    """

    def __init__(self, vocab_size, d_model=512, n_layers=6, n_heads=8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos = RelativePositionalEncoding(d_model, max_len=2048)
        self.blocks = nn.ModuleList(
            TransformerDecoderBlock(d_model, n_heads) for _ in range(n_layers)
        )
        self.ln = nn.LayerNorm(d_model)
        self.out = nn.Linear(d_model, vocab_size)

    def forward(self, x, pad_mask=None):
        """Return next-token logits of shape (B, L, vocab_size)."""
        hidden = self.embed(x) + self.pos(x)
        for block in self.blocks:
            hidden = block(hidden, pad_mask)
        return self.out(self.ln(hidden))

The Training Loop¶

  • Perplexity (PPL) as primary intrinsic metric.

    Lower PPL ≈ model is less “surprised” by true sequences, correlating with better musical coherence.

    $\mathrm{PPL} = \exp(\text{avg cross entropy}), \qquad \text{avg cross entropy} = \frac{\text{running loss}}{\text{train loader size}}$

  • Training details

    • Machine: 1 NVIDIA L4 GPU → g2-standard-8 (8 vCPUs, 32 GB Memory)
    • MAX_TOKENS = 1024, BATCH_SIZE = 8 → 50 epochs (≈ 7 h).
  • Vocabulary – 3,534 tokens, incl. 125 special conditioning IDs.

  • Training objective

    Classic language‑model framing turns music generation into a well‑studied optimisation problem.

    • Next‑token prediction (teacher‑forcing):
      • loss = cross_entropy(logits, target, ignore_index=PAD_ID, label_smoothing=0.1).
      • Padded positions are masked; causal + pad masks are merged to keep attention logits float32‐sized.
  • Optimisation regimen

    • AdamW

      • Learning Rate = 3 e‑4
      • weight‑decay = 1 e‑2
        • A smaller weight-decay was applied for regularisation purposes. The small decay term (1 × 10⁻²) discourages the weights from growing too large, helping generalisation.
    • Label Smoothing = 0.1

      Instead of treating the target token as probability 1.0, we soften it to 0.9 and spread 0.1 across the rest of the vocabulary. This prevents the model from becoming over‑confident and generally speeds convergence.

    • Gradient clipping (‖g‖₂ ≤ 1).

      Keeps exploding gradients in check by scaling the entire gradient vector to length 1 when it gets too large. That stabilises training, especially with long sequences.

    • LR Scheduler → ReduceLROnPlateau

      Watches the validation loss; if it hasn’t improved for 2 epochs, the scheduler cuts the current learning‑rate in half. That lets us start fast (3 e‑4) and automatically slow down when improvements plateau.

      • factor = 0.5
      • patience = 2
      • floor = 1 e‑6
    • 50 epochs, batch 8, max 1024 tokens.

In [95]:
from torch.utils.data import Dataset, DataLoader

import random
random.seed(42)  # For reproducibility
In [96]:
# Collect every tokenised song (.npy) produced by the tokenising pass.
tok_paths = []

for dirpath, _dirs, filenames in os.walk('data/tokens/train'):
    tok_paths.extend(
        os.path.join(dirpath, name)
        for name in filenames
        if name.endswith('.npy')
    )
In [97]:
len(tok_paths)
Out[97]:
6150
In [98]:
# 80/20 train-test split after an in-place shuffle (seeded above for
# reproducibility).
random.shuffle(tok_paths)

split_index = int(len(tok_paths) * 0.8)
train_paths = tok_paths[:split_index]
test_paths = tok_paths[split_index:]
In [104]:
# ─── 1. Dataset + collate ────────────────────────────────────────────────
class MidiTokenDataset(Dataset):
    """Lazily loads one tokenised song (.npy file) per index."""

    def __init__(self, npy_paths):
        self.paths = npy_paths

    def __len__(self):
        # Number of songs in this split
        return len(self.paths)

    def __getitem__(self, idx):
        # 1-D int64 array of token ids for song `idx`
        return np.load(self.paths[idx]).astype(np.int64)

def collate_fn(batch, pad_id):
    """Pad/crop a list of 1-D int token arrays into a fixed (B, MAX_TOKENS) batch.

    Sequences longer than MAX_TOKENS get a random crop (cheap augmentation);
    shorter ones are right-padded with `pad_id`. Returns (x, pad_mask) where
    pad_mask is True on real tokens and False on padding.
    """
    max_len = MAX_TOKENS
    x = torch.full((len(batch), max_len), pad_id, dtype=torch.long)
    for row, arr in zip(x, batch):
        seq = torch.from_numpy(arr)
        n = seq.numel()
        if n > max_len:
            # pick a random window of length max_len
            offset = torch.randint(0, n - max_len + 1, (1,)).item()
            seq = seq[offset : offset + max_len]
            n = max_len
        row[:n] = seq
    pad_mask = x.ne(pad_id)
    return x, pad_mask


# ─── 2. DataLoaders ──────────────────────────────────────────────────────
# Padding ID taken from the tokenizer vocab; the same ID is used as
# ignore_index in the loss so padded positions never contribute gradients.
PAD_ID = tokenizer.vocab['<PAD>']          # or use the ID you chose for <PAD>

train_ds = MidiTokenDataset(train_paths)
val_ds   = MidiTokenDataset(test_paths)

# The lambdas bind PAD_ID into collate_fn. NOTE(review): lambdas are not
# picklable, so this only works with the default num_workers=0 — switch to
# functools.partial(collate_fn, pad_id=PAD_ID) if worker processes are added.
train_loader = DataLoader(
    train_ds, batch_size=BATCH_SIZE, shuffle=True,
    collate_fn=lambda b: collate_fn(b, PAD_ID)
)
val_loader   = DataLoader(
    val_ds,   batch_size=BATCH_SIZE, shuffle=False,
    collate_fn=lambda b: collate_fn(b, PAD_ID)
)

# ─── 3. Model, optimiser, scheduler ─────────────────────────────────────
# Prefer CUDA, then Apple-Silicon MPS, then plain CPU.
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

model = Synphony(
    vocab_size=len(tokenizer), d_model=D_MODEL,
    n_layers=N_LAYERS, n_heads=N_HEADS).to(device)

# 1. Switch to AdamW with weight decay
optim = torch.optim.AdamW(model.parameters(),
                          lr=3e-4,           # whatever your current LR is
                          weight_decay=1e-2) # small wd to regularize

# 2. Set up a Reduce-on-Plateau scheduler
# NOTE(review): the `verbose` argument of ReduceLROnPlateau is deprecated in
# recent PyTorch releases (removed in 2.4+) — drop it when upgrading torch.
sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optim,
                                                  mode='min',        # val loss should go down
                                                  factor=0.5,        # cut LR in half
                                                  patience=2,        # wait 2 epochs
                                                  min_lr=1e-6,       # floor on LR
                                                  verbose=True)


# ─── 4. Training loop ────────────────────────────────────────────────────
# Tracks the best (lowest) summed validation loss for checkpointing.
best_val_loss = float("inf")

for epoch in tqdm(range(1, 51)):                         # 50 epochs
    # ---- train ----------------------------------------------------------
    model.train()
    running_loss = 0.0

    for x, pad_mask in train_loader:          # pad_mask: (B, L)
        x, pad_mask = x.to(device), pad_mask.to(device)

        # Teacher forcing: feed tokens 0..L-2, predict tokens 1..L-1.
        logits = model(x[:, :-1], pad_mask=pad_mask[:, :-1])

        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            x[:, 1:].reshape(-1),
            ignore_index=PAD_ID,        # padded targets contribute no loss
            label_smoothing=0.1
        )

        optim.zero_grad()
        loss.backward()
        # Clip the global grad norm to 1.0 to keep long-sequence training stable.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optim.step()

        running_loss += loss.item()

    # NOTE(review): train loss includes the label-smoothing penalty, so
    # train PPL is not directly comparable with val PPL (no smoothing there).
    train_ppl = math.exp(running_loss / len(train_loader))

    # ---- validation -----------------------------------------------------
    model.eval()
    val_loss = 0.0

    with torch.no_grad():
        for x, pad_mask in val_loader:             # pad_mask is (B, L)
            x, pad_mask = x.to(device), pad_mask.to(device)

            # exactly like in training, but without label smoothing
            logits = model(x[:, :-1], pad_mask=pad_mask[:, :-1])
            val_loss += F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                x[:, 1:].reshape(-1),
                ignore_index=PAD_ID
            ).item()

    # Compute the per-batch average once and reuse it (fixes the duplicated
    # "val PPL" print and the double division of the original loop).
    avg_val_loss = val_loss / len(val_loader)
    val_ppl = math.exp(avg_val_loss)
    print(f"Epoch {epoch:02d} ▸ train PPL {train_ppl:6.2f} | val PPL {val_ppl:6.2f}")

    # ---- scheduler step -----------------------------------------------
    sched.step(avg_val_loss)   # halves LR after `patience` epochs without improvement

    # log current LR
    current_lr = optim.param_groups[0]['lr']
    print(f"         ↳ LR now = {current_lr:.2e}")

    # ---- checkpoint -----------------------------------------------------
    # val_loss is a sum, but len(val_loader) is constant, so comparing sums
    # is equivalent to comparing averages.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "synphony_best.pt")
        print("  ✓ new best model saved")

print("Done!")
  0%|          | 0/50 [00:00<?, ?it/s]
val PPL  12.04
Epoch 01 ▸ train PPL  41.38 | val PPL  12.04
         ↳ LR now = 3.00e-04
  2%|▏         | 1/50 [08:14<6:43:27, 494.03s/it]
  ✓ new best model saved
val PPL   6.79
Epoch 02 ▸ train PPL  21.59 | val PPL   6.79
         ↳ LR now = 3.00e-04
  4%|▍         | 2/50 [16:31<6:36:43, 495.90s/it]
  ✓ new best model saved
val PPL   4.01
Epoch 03 ▸ train PPL  12.46 | val PPL   4.01
         ↳ LR now = 3.00e-04
  6%|▌         | 3/50 [24:48<6:28:57, 496.54s/it]
  ✓ new best model saved
val PPL   3.48
Epoch 04 ▸ train PPL   9.59 | val PPL   3.48
         ↳ LR now = 3.00e-04
  8%|▊         | 4/50 [33:05<6:20:55, 496.85s/it]
  ✓ new best model saved
val PPL   3.28
Epoch 05 ▸ train PPL   8.84 | val PPL   3.28
         ↳ LR now = 3.00e-04
 10%|█         | 5/50 [41:23<6:12:51, 497.16s/it]
  ✓ new best model saved
val PPL   3.17
Epoch 06 ▸ train PPL   8.44 | val PPL   3.17
         ↳ LR now = 3.00e-04
 12%|█▏        | 6/50 [49:40<6:04:39, 497.25s/it]
  ✓ new best model saved
val PPL   3.06
Epoch 07 ▸ train PPL   8.14 | val PPL   3.06
         ↳ LR now = 3.00e-04
 14%|█▍        | 7/50 [57:58<5:56:26, 497.36s/it]
  ✓ new best model saved
val PPL   3.03
Epoch 08 ▸ train PPL   7.97 | val PPL   3.03
         ↳ LR now = 3.00e-04
 16%|█▌        | 8/50 [1:06:15<5:48:09, 497.36s/it]
  ✓ new best model saved
val PPL   2.97
Epoch 09 ▸ train PPL   7.81 | val PPL   2.97
         ↳ LR now = 3.00e-04
 18%|█▊        | 9/50 [1:14:33<5:39:53, 497.39s/it]
  ✓ new best model saved
val PPL   2.92
Epoch 10 ▸ train PPL   7.64 | val PPL   2.92
         ↳ LR now = 3.00e-04
 20%|██        | 10/50 [1:22:50<5:31:34, 497.37s/it]
  ✓ new best model saved
val PPL   2.87
Epoch 11 ▸ train PPL   7.55 | val PPL   2.87
         ↳ LR now = 3.00e-04
 22%|██▏       | 11/50 [1:31:08<5:23:17, 497.38s/it]
  ✓ new best model saved
 24%|██▍       | 12/50 [1:39:24<5:14:51, 497.14s/it]
val PPL   2.87
Epoch 12 ▸ train PPL   7.42 | val PPL   2.87
         ↳ LR now = 3.00e-04
val PPL   2.82
Epoch 13 ▸ train PPL   7.35 | val PPL   2.82
         ↳ LR now = 3.00e-04
 26%|██▌       | 13/50 [1:47:42<5:06:43, 497.40s/it]
  ✓ new best model saved
val PPL   2.80
Epoch 14 ▸ train PPL   7.28 | val PPL   2.80
         ↳ LR now = 3.00e-04
 28%|██▊       | 14/50 [1:56:00<4:58:31, 497.54s/it]
  ✓ new best model saved
val PPL   2.77
Epoch 15 ▸ train PPL   7.23 | val PPL   2.77
         ↳ LR now = 3.00e-04
 30%|███       | 15/50 [2:04:18<4:50:18, 497.66s/it]
  ✓ new best model saved
val PPL   2.76
Epoch 16 ▸ train PPL   7.15 | val PPL   2.76
         ↳ LR now = 3.00e-04
 32%|███▏      | 16/50 [2:12:36<4:42:01, 497.70s/it]
  ✓ new best model saved
val PPL   2.74
Epoch 17 ▸ train PPL   7.10 | val PPL   2.74
         ↳ LR now = 3.00e-04
 34%|███▍      | 17/50 [2:20:54<4:33:44, 497.73s/it]
  ✓ new best model saved
val PPL   2.71
Epoch 18 ▸ train PPL   7.05 | val PPL   2.71
         ↳ LR now = 3.00e-04
 36%|███▌      | 18/50 [2:29:11<4:25:23, 497.62s/it]
  ✓ new best model saved
val PPL   2.67
Epoch 19 ▸ train PPL   7.00 | val PPL   2.67
         ↳ LR now = 3.00e-04
 38%|███▊      | 19/50 [2:37:29<4:17:08, 497.69s/it]
  ✓ new best model saved
 40%|████      | 20/50 [2:45:45<4:08:35, 497.17s/it]
val PPL   2.70
Epoch 20 ▸ train PPL   6.96 | val PPL   2.70
         ↳ LR now = 3.00e-04
 42%|████▏     | 21/50 [2:54:01<4:00:11, 496.95s/it]
val PPL   2.68
Epoch 21 ▸ train PPL   6.92 | val PPL   2.68
         ↳ LR now = 3.00e-04
val PPL   2.66
Epoch 22 ▸ train PPL   6.88 | val PPL   2.66
         ↳ LR now = 3.00e-04
 44%|████▍     | 22/50 [3:02:19<3:52:00, 497.15s/it]
  ✓ new best model saved
 46%|████▌     | 23/50 [3:10:35<3:43:35, 496.88s/it]
val PPL   2.66
Epoch 23 ▸ train PPL   6.82 | val PPL   2.66
         ↳ LR now = 3.00e-04
val PPL   2.63
Epoch 24 ▸ train PPL   6.80 | val PPL   2.63
         ↳ LR now = 3.00e-04
 48%|████▊     | 24/50 [3:18:53<3:35:24, 497.11s/it]
  ✓ new best model saved
 50%|█████     | 25/50 [3:27:09<3:26:59, 496.79s/it]
val PPL   2.65
Epoch 25 ▸ train PPL   6.79 | val PPL   2.65
         ↳ LR now = 3.00e-04
val PPL   2.62
Epoch 26 ▸ train PPL   6.75 | val PPL   2.62
         ↳ LR now = 3.00e-04
 52%|█████▏    | 26/50 [3:35:26<3:18:49, 497.05s/it]
  ✓ new best model saved
 54%|█████▍    | 27/50 [3:43:42<3:10:24, 496.72s/it]
val PPL   2.63
Epoch 27 ▸ train PPL   6.73 | val PPL   2.63
         ↳ LR now = 3.00e-04
val PPL   2.61
Epoch 28 ▸ train PPL   6.68 | val PPL   2.61
         ↳ LR now = 3.00e-04
 56%|█████▌    | 28/50 [3:52:00<3:02:12, 496.93s/it]
  ✓ new best model saved
 58%|█████▊    | 29/50 [4:00:16<2:53:51, 496.73s/it]
val PPL   2.61
Epoch 29 ▸ train PPL   6.67 | val PPL   2.61
         ↳ LR now = 3.00e-04
val PPL   2.60
Epoch 30 ▸ train PPL   6.64 | val PPL   2.60
         ↳ LR now = 3.00e-04
 60%|██████    | 30/50 [4:08:34<2:45:42, 497.11s/it]
  ✓ new best model saved
val PPL   2.57
Epoch 31 ▸ train PPL   6.65 | val PPL   2.57
         ↳ LR now = 3.00e-04
 62%|██████▏   | 31/50 [4:16:52<2:37:28, 497.31s/it]
  ✓ new best model saved
 64%|██████▍   | 32/50 [4:25:08<2:29:05, 496.96s/it]
val PPL   2.59
Epoch 32 ▸ train PPL   6.59 | val PPL   2.59
         ↳ LR now = 3.00e-04
 66%|██████▌   | 33/50 [4:33:24<2:20:44, 496.72s/it]
val PPL   2.58
Epoch 33 ▸ train PPL   6.57 | val PPL   2.58
         ↳ LR now = 3.00e-04
val PPL   2.55
Epoch 34 ▸ train PPL   6.56 | val PPL   2.55
         ↳ LR now = 3.00e-04
 68%|██████▊   | 34/50 [4:41:42<2:12:32, 497.02s/it]
  ✓ new best model saved
 70%|███████   | 35/50 [4:49:58<2:04:11, 496.74s/it]
val PPL   2.57
Epoch 35 ▸ train PPL   6.53 | val PPL   2.57
         ↳ LR now = 3.00e-04
 72%|███████▏  | 36/50 [4:58:14<1:55:51, 496.52s/it]
val PPL   2.57
Epoch 36 ▸ train PPL   6.53 | val PPL   2.57
         ↳ LR now = 3.00e-04
val PPL   2.55
Epoch 37 ▸ train PPL   6.53 | val PPL   2.55
         ↳ LR now = 3.00e-04
 74%|███████▍  | 37/50 [5:06:31<1:47:38, 496.79s/it]
  ✓ new best model saved
val PPL   2.54
Epoch 38 ▸ train PPL   6.49 | val PPL   2.54
         ↳ LR now = 3.00e-04
 76%|███████▌  | 38/50 [5:14:49<1:39:24, 497.00s/it]
  ✓ new best model saved
 78%|███████▊  | 39/50 [5:23:05<1:31:03, 496.68s/it]
val PPL   2.55
Epoch 39 ▸ train PPL   6.45 | val PPL   2.55
         ↳ LR now = 3.00e-04
 80%|████████  | 40/50 [5:31:21<1:22:44, 496.46s/it]
val PPL   2.55
Epoch 40 ▸ train PPL   6.45 | val PPL   2.55
         ↳ LR now = 3.00e-04
val PPL   2.54
Epoch 41 ▸ train PPL   6.44 | val PPL   2.54
         ↳ LR now = 3.00e-04
 82%|████████▏ | 41/50 [5:39:38<1:14:30, 496.69s/it]
  ✓ new best model saved
val PPL   2.52
Epoch 42 ▸ train PPL   6.40 | val PPL   2.52
         ↳ LR now = 3.00e-04
 84%|████████▍ | 42/50 [5:47:55<1:06:15, 496.91s/it]
  ✓ new best model saved
 86%|████████▌ | 43/50 [5:56:11<57:56, 496.60s/it]  
val PPL   2.52
Epoch 43 ▸ train PPL   6.41 | val PPL   2.52
         ↳ LR now = 3.00e-04
 88%|████████▊ | 44/50 [6:04:27<49:38, 496.37s/it]
val PPL   2.53
Epoch 44 ▸ train PPL   6.38 | val PPL   2.53
         ↳ LR now = 3.00e-04
val PPL   2.52
Epoch 45 ▸ train PPL   6.38 | val PPL   2.52
         ↳ LR now = 1.50e-04
 90%|█████████ | 45/50 [6:12:45<41:23, 496.68s/it]
  ✓ new best model saved
val PPL   2.46
Epoch 46 ▸ train PPL   6.20 | val PPL   2.46
         ↳ LR now = 1.50e-04
 92%|█████████▏| 46/50 [6:21:02<33:07, 496.86s/it]
  ✓ new best model saved
val PPL   2.45
Epoch 47 ▸ train PPL   6.12 | val PPL   2.45
         ↳ LR now = 1.50e-04
 94%|█████████▍| 47/50 [6:29:19<24:50, 496.93s/it]
  ✓ new best model saved
 96%|█████████▌| 48/50 [6:37:35<16:33, 496.73s/it]
val PPL   2.46
Epoch 48 ▸ train PPL   6.09 | val PPL   2.46
         ↳ LR now = 1.50e-04
 98%|█████████▊| 49/50 [6:45:51<08:16, 496.51s/it]
val PPL   2.45
Epoch 49 ▸ train PPL   6.06 | val PPL   2.45
         ↳ LR now = 1.50e-04
val PPL   2.43
Epoch 50 ▸ train PPL   6.04 | val PPL   2.43
         ↳ LR now = 1.50e-04
100%|██████████| 50/50 [6:54:09<00:00, 496.99s/it]
  ✓ new best model saved
Done!

In [72]:
tokenizer.vocab_size
Out[72]:
3534

3. Model Inference¶

In [105]:
model.eval()
Out[105]:
Synphony(
  (embed): Embedding(3534, 768)
  (pos): RelativePositionalEncoding()
  (blocks): ModuleList(
    (0-7): 8 x TransformerDecoderBlock(
      (self_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
      )
      (ln1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (ff): Sequential(
        (0): Linear(in_features=768, out_features=3072, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=3072, out_features=768, bias=True)
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (ln): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (out): Linear(in_features=768, out_features=3534, bias=True)
)
In [106]:
TEMPERATURE = 1.0   # >1 flattens the sampling distribution, <1 sharpens it
TOP_K = 8           # sample only among the 8 most likely next tokens

# ─── 2. Helper for top-k filtering ───────────────────────────────────────
def top_k_logits(logits, k):
    """Mask all but the top-k logits with -inf.

    Generalized to work on both 1-D (V,) and batched (..., V) logits: the
    threshold is the k-th largest value along the last dim, kept as a
    trailing singleton so it broadcasts. Behaviour on 1-D input is unchanged
    (values tied with the k-th value are kept, matching the original).
    """
    v, _ = torch.topk(logits, k)          # values sorted descending along last dim
    threshold = v[..., -1, None]          # k-th largest, broadcastable
    return torch.where(logits < threshold, torch.full_like(logits, -float("Inf")), logits)

# ─── 3. Autoregressive generation ────────────────────────────────────────
@torch.no_grad()
def generate(
        genre: str,
        artist: str,
        year: int,
        max_length: int = MAX_TOKENS
    ) -> list[int]:
    """Sample a token sequence conditioned on (genre, artist, year) metadata.

    At each step: temperature-scale the last-position logits, keep the top-k,
    and sample one token. Stops early when <EOS> is drawn (the <EOS> token is
    NOT appended to the returned sequence). Returns the full list of token
    IDs, prefix included.
    """
    prefix = build_prefix(genre, artist, year, tokenizer)
    input_ids = torch.tensor([prefix], device=device)                   # (1, P)
    # Generation has no padding, so the mask is all True.
    pad_mask = torch.ones_like(input_ids, dtype=torch.bool, device=device)
    eos_id = tokenizer.vocab["<EOS>"]
    one_true = torch.ones((1, 1), dtype=torch.bool, device=device)

    for _ in tqdm(range(max_length - len(prefix))):
        logits = model(input_ids, pad_mask=pad_mask)
        next_logits = logits[0, -1, :] / TEMPERATURE                    # (V,)
        next_logits = top_k_logits(next_logits, TOP_K)
        probs = F.softmax(next_logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)               # (1,)
        if next_id.item() == eos_id:
            break

        # Append the new token; extend pad_mask by a single True instead of
        # rebuilding it from scratch each step (was O(L²) total work).
        input_ids = torch.cat([input_ids, next_id.unsqueeze(0)], dim=1)  # (1, L+1)
        pad_mask = torch.cat([pad_mask, one_true], dim=1)

    return input_ids[0].tolist()

# ─── 4. Decode to MIDI & save ────────────────────────────────────────────
def tokens_to_midi(token_ids: list[int], out_path: str, n_prefix: int = 3):
    """Strip the metadata prefix and a trailing <EOS>, decode, and write a .mid.

    Args:
        token_ids: full generated sequence (metadata prefix + musical tokens,
                   optionally ending in <EOS>).
        out_path:  destination path for the MIDI file.
        n_prefix:  number of leading metadata tokens to drop
                   (default 3: genre, artist, year — matching build_prefix).
    """
    # 1) drop the leading metadata token IDs
    musical_ids = token_ids[n_prefix:]
    # 2) drop trailing <EOS> if present
    eos_id = tokenizer.vocab["<EOS>"]
    if musical_ids and musical_ids[-1] == eos_id:
        musical_ids = musical_ids[:-1]

    # 3) decode only the musical tokens. NOTE: .dump_midi is a symusic Score
    #    method, so the decoded object is a symusic Score, not a PrettyMIDI.
    score = tokenizer(musical_ids)
    # 4) write out the .mid file
    score.dump_midi(out_path)
In [120]:
# ─── 5. Run it! ───────────────────────────────────────────────────────────
# Example user inputs
genre_input  = "ROCK"
artist_input = "GLORIA_GAYNOR"
year_input   = 1990

gen_ids = generate(genre_input, artist_input, year_input, max_length=512)
out_file = "generated.mid"
tokens_to_midi(gen_ids, out_file)
print(f"🎹 Wrote MIDI to {out_file}")
100%|██████████| 509/509 [00:03<00:00, 131.18it/s]
🎹 Wrote MIDI to generated.mid

In [121]:
from midi2audio import FluidSynth
from IPython.display import Audio

# render your MIDI to a WAV
fs = FluidSynth()
fs.midi_to_audio('generated.mid', 'generated.wav')

# now embed the WAV inline
Audio('generated.wav')
Parameter '/home/jupyter/.fluidsynth/default_sound_font.sf2' not a SoundFont or MIDI file or error occurred identifying it.
FluidSynth runtime version 2.1.7
Copyright (C) 2000-2021 Peter Hanappe and others.
Distributed under the LGPL license.
SoundFont(R) is a registered trademark of E-mu Systems, Inc.

Rendering audio to file 'generated.wav'..
Out[121]:
Your browser does not support the audio element.
In [ ]: